{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "root\n",
      " |-- School: string (nullable = true)\n",
      " |-- Private: string (nullable = true)\n",
      " |-- Apps: integer (nullable = true)\n",
      " |-- Accept: integer (nullable = true)\n",
      " |-- Enroll: integer (nullable = true)\n",
      " |-- Top10perc: integer (nullable = true)\n",
      " |-- Top25perc: integer (nullable = true)\n",
      " |-- F_Undergrad: integer (nullable = true)\n",
      " |-- P_Undergrad: integer (nullable = true)\n",
      " |-- Outstate: integer (nullable = true)\n",
      " |-- Room_Board: integer (nullable = true)\n",
      " |-- Books: integer (nullable = true)\n",
      " |-- Personal: integer (nullable = true)\n",
      " |-- PhD: integer (nullable = true)\n",
      " |-- Terminal: integer (nullable = true)\n",
      " |-- S_F_Ratio: double (nullable = true)\n",
      " |-- perc_alumni: integer (nullable = true)\n",
      " |-- Expend: integer (nullable = true)\n",
      " |-- Grad_Rate: integer (nullable = true)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from pyspark.sql import SparkSession\n",
    "\n",
    "# Load the College data set. inferSchema=True lets Spark type the numeric\n",
    "# columns (integer/double, see the printed schema below).\n",
    "DATA_PATH = 'College.csv'\n",
    "\n",
    "spark = (SparkSession.builder\n",
    "         .appName('tree')\n",
    "         .getOrCreate())\n",
    "df = spark.read.csv(DATA_PATH, header=True, inferSchema=True)\n",
    "df.printSchema()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "      <th>4</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>School</th>\n",
       "      <td>Abilene Christian University</td>\n",
       "      <td>Adelphi University</td>\n",
       "      <td>Adrian College</td>\n",
       "      <td>Agnes Scott College</td>\n",
       "      <td>Alaska Pacific University</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Private</th>\n",
       "      <td>Yes</td>\n",
       "      <td>Yes</td>\n",
       "      <td>Yes</td>\n",
       "      <td>Yes</td>\n",
       "      <td>Yes</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Apps</th>\n",
       "      <td>1660</td>\n",
       "      <td>2186</td>\n",
       "      <td>1428</td>\n",
       "      <td>417</td>\n",
       "      <td>193</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Accept</th>\n",
       "      <td>1232</td>\n",
       "      <td>1924</td>\n",
       "      <td>1097</td>\n",
       "      <td>349</td>\n",
       "      <td>146</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Enroll</th>\n",
       "      <td>721</td>\n",
       "      <td>512</td>\n",
       "      <td>336</td>\n",
       "      <td>137</td>\n",
       "      <td>55</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Top10perc</th>\n",
       "      <td>23</td>\n",
       "      <td>16</td>\n",
       "      <td>22</td>\n",
       "      <td>60</td>\n",
       "      <td>16</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Top25perc</th>\n",
       "      <td>52</td>\n",
       "      <td>29</td>\n",
       "      <td>50</td>\n",
       "      <td>89</td>\n",
       "      <td>44</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>F_Undergrad</th>\n",
       "      <td>2885</td>\n",
       "      <td>2683</td>\n",
       "      <td>1036</td>\n",
       "      <td>510</td>\n",
       "      <td>249</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>P_Undergrad</th>\n",
       "      <td>537</td>\n",
       "      <td>1227</td>\n",
       "      <td>99</td>\n",
       "      <td>63</td>\n",
       "      <td>869</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Outstate</th>\n",
       "      <td>7440</td>\n",
       "      <td>12280</td>\n",
       "      <td>11250</td>\n",
       "      <td>12960</td>\n",
       "      <td>7560</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Room_Board</th>\n",
       "      <td>3300</td>\n",
       "      <td>6450</td>\n",
       "      <td>3750</td>\n",
       "      <td>5450</td>\n",
       "      <td>4120</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Books</th>\n",
       "      <td>450</td>\n",
       "      <td>750</td>\n",
       "      <td>400</td>\n",
       "      <td>450</td>\n",
       "      <td>800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Personal</th>\n",
       "      <td>2200</td>\n",
       "      <td>1500</td>\n",
       "      <td>1165</td>\n",
       "      <td>875</td>\n",
       "      <td>1500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>PhD</th>\n",
       "      <td>70</td>\n",
       "      <td>29</td>\n",
       "      <td>53</td>\n",
       "      <td>92</td>\n",
       "      <td>76</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Terminal</th>\n",
       "      <td>78</td>\n",
       "      <td>30</td>\n",
       "      <td>66</td>\n",
       "      <td>97</td>\n",
       "      <td>72</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>S_F_Ratio</th>\n",
       "      <td>18.1</td>\n",
       "      <td>12.2</td>\n",
       "      <td>12.9</td>\n",
       "      <td>7.7</td>\n",
       "      <td>11.9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>perc_alumni</th>\n",
       "      <td>12</td>\n",
       "      <td>16</td>\n",
       "      <td>30</td>\n",
       "      <td>37</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Expend</th>\n",
       "      <td>7041</td>\n",
       "      <td>10527</td>\n",
       "      <td>8735</td>\n",
       "      <td>19016</td>\n",
       "      <td>10922</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Grad_Rate</th>\n",
       "      <td>60</td>\n",
       "      <td>56</td>\n",
       "      <td>54</td>\n",
       "      <td>59</td>\n",
       "      <td>15</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                        0                   1               2  \\\n",
       "School       Abilene Christian University  Adelphi University  Adrian College   \n",
       "Private                               Yes                 Yes             Yes   \n",
       "Apps                                 1660                2186            1428   \n",
       "Accept                               1232                1924            1097   \n",
       "Enroll                                721                 512             336   \n",
       "Top10perc                              23                  16              22   \n",
       "Top25perc                              52                  29              50   \n",
       "F_Undergrad                          2885                2683            1036   \n",
       "P_Undergrad                           537                1227              99   \n",
       "Outstate                             7440               12280           11250   \n",
       "Room_Board                           3300                6450            3750   \n",
       "Books                                 450                 750             400   \n",
       "Personal                             2200                1500            1165   \n",
       "PhD                                    70                  29              53   \n",
       "Terminal                               78                  30              66   \n",
       "S_F_Ratio                            18.1                12.2            12.9   \n",
       "perc_alumni                            12                  16              30   \n",
       "Expend                               7041               10527            8735   \n",
       "Grad_Rate                              60                  56              54   \n",
       "\n",
       "                               3                          4  \n",
       "School       Agnes Scott College  Alaska Pacific University  \n",
       "Private                      Yes                        Yes  \n",
       "Apps                         417                        193  \n",
       "Accept                       349                        146  \n",
       "Enroll                       137                         55  \n",
       "Top10perc                     60                         16  \n",
       "Top25perc                     89                         44  \n",
       "F_Undergrad                  510                        249  \n",
       "P_Undergrad                   63                        869  \n",
       "Outstate                   12960                       7560  \n",
       "Room_Board                  5450                       4120  \n",
       "Books                        450                        800  \n",
       "Personal                     875                       1500  \n",
       "PhD                           92                         76  \n",
       "Terminal                      97                         72  \n",
       "S_F_Ratio                    7.7                       11.9  \n",
       "perc_alumni                   37                          2  \n",
       "Expend                     19016                      10922  \n",
       "Grad_Rate                     59                         15  "
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Peek at the first five rows. The frame has 19 columns, so display it\n",
    "# transposed: one row per column, one column per school.\n",
    "first_rows = df.take(5)\n",
    "pd.DataFrame(first_rows, columns=df.columns).T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['School',\n",
       " 'Private',\n",
       " 'Apps',\n",
       " 'Accept',\n",
       " 'Enroll',\n",
       " 'Top10perc',\n",
       " 'Top25perc',\n",
       " 'F_Undergrad',\n",
       " 'P_Undergrad',\n",
       " 'Outstate',\n",
       " 'Room_Board',\n",
       " 'Books',\n",
       " 'Personal',\n",
       " 'PhD',\n",
       " 'Terminal',\n",
       " 'S_F_Ratio',\n",
       " 'perc_alumni',\n",
       " 'Expend',\n",
       " 'Grad_Rate']"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Column names in schema order; everything except 'School' and 'Private'\n",
    "# becomes a model feature below.\n",
    "df.schema.names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml.feature import VectorAssembler\n",
    "\n",
    "# Every column except the identifier ('School') and the label ('Private') is a\n",
    "# feature. Deriving the list from the schema keeps it in sync with the CSV instead\n",
    "# of hand-maintaining 17 names; the order still follows df.columns.\n",
    "NON_FEATURE_COLS = {'School', 'Private'}\n",
    "feature_cols = [name for name in df.columns if name not in NON_FEATURE_COLS]\n",
    "\n",
    "assembler = VectorAssembler(inputCols = feature_cols, outputCol = 'features')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Append the assembled 'features' vector column to every row of df.\n",
    "output = assembler.transform(dataset = df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "root\n",
      " |-- School: string (nullable = true)\n",
      " |-- Private: string (nullable = true)\n",
      " |-- Apps: integer (nullable = true)\n",
      " |-- Accept: integer (nullable = true)\n",
      " |-- Enroll: integer (nullable = true)\n",
      " |-- Top10perc: integer (nullable = true)\n",
      " |-- Top25perc: integer (nullable = true)\n",
      " |-- F_Undergrad: integer (nullable = true)\n",
      " |-- P_Undergrad: integer (nullable = true)\n",
      " |-- Outstate: integer (nullable = true)\n",
      " |-- Room_Board: integer (nullable = true)\n",
      " |-- Books: integer (nullable = true)\n",
      " |-- Personal: integer (nullable = true)\n",
      " |-- PhD: integer (nullable = true)\n",
      " |-- Terminal: integer (nullable = true)\n",
      " |-- S_F_Ratio: double (nullable = true)\n",
      " |-- perc_alumni: integer (nullable = true)\n",
      " |-- Expend: integer (nullable = true)\n",
      " |-- Grad_Rate: integer (nullable = true)\n",
      " |-- features: vector (nullable = true)\n",
      " |-- PrivateIndex: double (nullable = false)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from pyspark.ml.feature import StringIndexer\n",
    "\n",
    "# Encode the string Private flag as a numeric label. StringIndexer orders labels by\n",
    "# descending frequency by default, which here maps 'Yes' to 0.0 (see the rows shown\n",
    "# by the next cell). Label 1.0 -- the class the binary evaluator treats as positive --\n",
    "# is therefore the other value (presumably 'No').\n",
    "indexer = StringIndexer(inputCol = 'Private', outputCol = 'PrivateIndex')\n",
    "indexer_model = indexer.fit(output)\n",
    "outputFixed = indexer_model.transform(output)\n",
    "outputFixed.printSchema()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------------------+------------+\n",
      "|            features|PrivateIndex|\n",
      "+--------------------+------------+\n",
      "|[1660.0,1232.0,72...|         0.0|\n",
      "|[2186.0,1924.0,51...|         0.0|\n",
      "|[1428.0,1097.0,33...|         0.0|\n",
      "+--------------------+------------+\n",
      "only showing top 3 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Keep only what the classifiers consume: the feature vector and the label.\n",
    "final_df = outputFixed.select(['features', 'PrivateIndex'])\n",
    "final_df.show(n = 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed the split so the train/test partition -- and every metric reported\n",
    "# below -- is reproducible across kernel restarts.\n",
    "SPLIT_SEED = 42\n",
    "train, test = final_df.randomSplit([0.7, 0.3], seed = SPLIT_SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml.classification import (DecisionTreeClassifier, RandomForestClassifier,\n",
    "                                       GBTClassifier)\n",
    "\n",
    "# Pin the seed explicitly: PySpark's default seed is derived from hash() of the class\n",
    "# name, which Python 3 randomizes per process (unless PYTHONHASHSEED is set), so the\n",
    "# random forest's bootstrap/feature sampling would otherwise change on every restart.\n",
    "MODEL_SEED = 42\n",
    "\n",
    "dt = DecisionTreeClassifier(labelCol = 'PrivateIndex', featuresCol = 'features', seed = MODEL_SEED)\n",
    "rf = RandomForestClassifier(labelCol = 'PrivateIndex', featuresCol = 'features', seed = MODEL_SEED)\n",
    "gb = GBTClassifier(labelCol = 'PrivateIndex', featuresCol = 'features', seed = MODEL_SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit each classifier (in order: tree, forest, boosted trees) on the same training split.\n",
    "dt_model, rf_model, gb_model = (estimator.fit(train) for estimator in (dt, rf, gb))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the held-out test split with each fitted model.\n",
    "dt_predictions, rf_predictions, gb_predictions = (\n",
    "    fitted.transform(test) for fitted in (dt_model, rf_model, gb_model)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Decision Tree: 0.9042764345001805\n"
     ]
    }
   ],
   "source": [
    "from pyspark.ml.evaluation import BinaryClassificationEvaluator\n",
    "\n",
    "# State the metric explicitly (these are also the evaluator's defaults) so the printed\n",
    "# numbers are self-describing: area under the ROC curve, scored from rawPrediction.\n",
    "binary_evaluator = BinaryClassificationEvaluator(labelCol = 'PrivateIndex',\n",
    "                                                 rawPredictionCol = 'rawPrediction',\n",
    "                                                 metricName = 'areaUnderROC')\n",
    "\n",
    "print('Decision Tree AUC:', binary_evaluator.evaluate(dt_predictions))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Random Forest: 0.9822717430530494\n"
     ]
    }
   ],
   "source": [
    "# binary_evaluator reports area under the ROC curve.\n",
    "print('Random Forest AUC:', binary_evaluator.evaluate(rf_predictions))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Gradient-boosted Trees: 0.9343648502345726\n"
     ]
    }
   ],
   "source": [
    "# binary_evaluator reports area under the ROC curve.\n",
    "print('Gradient-boosted Trees AUC:', binary_evaluator.evaluate(gb_predictions))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Decision Tree Accu: 0.9047619047619048\n"
     ]
    }
   ],
   "source": [
    "from pyspark.ml.evaluation import MulticlassClassificationEvaluator\n",
    "\n",
    "# Accuracy: fraction of test rows whose predicted label equals PrivateIndex.\n",
    "multi_evaluator = MulticlassClassificationEvaluator(\n",
    "    labelCol = 'PrivateIndex',\n",
    "    metricName = 'accuracy',\n",
    ")\n",
    "dt_accuracy = multi_evaluator.evaluate(dt_predictions)\n",
    "print('Decision Tree Accu:', dt_accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'accuracy'"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: confirm which metric multi_evaluator reports.\n",
    "metric_name = multi_evaluator.getMetricName()\n",
    "metric_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Random Forest Accu: 0.935064935064935\n"
     ]
    }
   ],
   "source": [
    "# Accuracy of the random forest on the test split.\n",
    "rf_accuracy = multi_evaluator.evaluate(rf_predictions)\n",
    "print('Random Forest Accu:', rf_accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Gradient-boosted Trees Accu: 0.9090909090909091\n"
     ]
    }
   ],
   "source": [
    "# Accuracy of the gradient-boosted trees on the test split.\n",
    "gb_accuracy = multi_evaluator.evaluate(gb_predictions)\n",
    "print('Gradient-boosted Trees Accu:', gb_accuracy)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
